{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "7765UFHoyGx6"
      },
      "source": [
        "##### Copyright 2020 The TensorFlow Authors."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "cellView": "form",
        "id": "KsOkK8O69PyT"
      },
      "outputs": [],
      "source": [
        "#@title Licensed under the Apache License, Version 2.0 (the \"License\");\n",
        "# you may not use this file except in compliance with the License.\n",
        "# You may obtain a copy of the License at\n",
        "#\n",
        "# https://www.apache.org/licenses/LICENSE-2.0\n",
        "#\n",
        "# Unless required by applicable law or agreed to in writing, software\n",
        "# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
        "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
        "# See the License for the specific language governing permissions and\n",
        "# limitations under the License."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "ZS8z-_KeywY9"
      },
      "source": [
        "# TF Lattice Canned Estimator"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "r61fkA2i9Y3_"
      },
      "source": [
        "<table class=\"tfo-notebook-buttons\" align=\"left\">\n",
        "  <td><a target=\"_blank\" href=\"https://tensorflow.google.cn/lattice/tutorials/canned_estimators\"><img src=\"https://tensorflow.google.cn/images/tf_logo_32px.png\">View on TensorFlow.org</a></td>\n",
        "  <td><a target=\"_blank\" href=\"https://colab.research.google.com/github/tensorflow/docs-l10n/blob/master/site/zh-cn/lattice/tutorials/canned_estimators.ipynb\"><img src=\"https://tensorflow.google.cn/images/colab_logo_32px.png\">Run in Google Colab</a></td>\n",
        "  <td><a target=\"_blank\" href=\"https://github.com/tensorflow/docs-l10n/blob/master/site/zh-cn/lattice/tutorials/canned_estimators.ipynb\"><img src=\"https://tensorflow.google.cn/images/GitHub-Mark-32px.png\">View source on GitHub</a></td>\n",
        "  <td><a href=\"https://storage.googleapis.com/tensorflow_docs/docs-l10n/site/zh-cn/lattice/tutorials/canned_estimators.ipynb\"><img src=\"https://tensorflow.google.cn/images/download_logo_32px.png\">Download notebook</a></td>\n",
        "</table>"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "WCpl-9WDVq9d"
      },
      "source": [
        "## Overview\n",
        "\n",
        "Canned estimators are a quick and easy way to train TFL models for typical use cases. This guide outlines the steps needed to create a TFL canned estimator."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "x769lI12IZXB"
      },
      "source": [
        "## Setup"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "fbBVAR6UeRN5"
      },
      "source": [
        "Install the TF Lattice package:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "bpXjJKpSd3j4"
      },
      "outputs": [],
      "source": [
        "#@test {\"skip\": true}\n",
        "!pip install tensorflow-lattice"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "jSVl9SHTeSGX"
      },
      "source": [
        "Import the required packages:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "cellView": "both",
        "id": "FbZDk8bIx8ig"
      },
      "outputs": [],
      "source": [
        "import tensorflow as tf\n",
        "\n",
        "# copy is needed later for copy.deepcopy(feature_configs).\n",
        "import copy\n",
        "import logging\n",
        "import numpy as np\n",
        "import pandas as pd\n",
        "import sys\n",
        "import tensorflow_lattice as tfl\n",
        "from tensorflow import feature_column as fc\n",
        "logging.disable(sys.maxsize)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "svPuM6QNxlrH"
      },
      "source": [
        "Download the UCI Statlog (Heart) dataset:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "cellView": "both",
        "id": "j-k1qTR_yvBl"
      },
      "outputs": [],
      "source": [
        "csv_file = tf.keras.utils.get_file(\n",
        "    'heart.csv', 'http://storage.googleapis.com/applied-dl/heart.csv')\n",
        "df = pd.read_csv(csv_file)\n",
        "target = df.pop('target')\n",
        "train_size = int(len(df) * 0.8)\n",
        "train_x = df[:train_size]\n",
        "train_y = target[:train_size]\n",
        "test_x = df[train_size:]\n",
        "test_y = target[train_size:]\n",
        "df.head()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "nKkAw12SxvGG"
      },
      "source": [
        "Set the default values used for training in this guide:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "cellView": "both",
        "id": "1T6GFI9F6mcG"
      },
      "outputs": [],
      "source": [
        "LEARNING_RATE = 0.01\n",
        "BATCH_SIZE = 128\n",
        "NUM_EPOCHS = 500\n",
        "PREFITTING_NUM_EPOCHS = 10"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "0TGfzhPHzpix"
      },
      "source": [
        "## Feature Columns\n",
        "\n",
        "As with any other TF estimator, data is typically passed to the estimator through an input_fn and parsed using [FeatureColumns](https://tensorflow.google.cn/guide/feature_columns)."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "DCIUz8apzs0l"
      },
      "outputs": [],
      "source": [
        "# Feature columns.\n",
        "# - age\n",
        "# - sex\n",
        "# - cp        chest pain type (4 values)\n",
        "# - trestbps  resting blood pressure\n",
        "# - chol      serum cholesterol in mg/dl\n",
        "# - fbs       fasting blood sugar > 120 mg/dl\n",
        "# - restecg   resting electrocardiographic results (values 0,1,2)\n",
        "# - thalach   maximum heart rate achieved\n",
        "# - exang     exercise induced angina\n",
        "# - oldpeak   ST depression induced by exercise relative to rest\n",
        "# - slope     the slope of the peak exercise ST segment\n",
        "# - ca        number of major vessels (0-3) colored by fluoroscopy\n",
        "# - thal      3 = normal; 6 = fixed defect; 7 = reversible defect\n",
        "feature_columns = [\n",
        "    fc.numeric_column('age', default_value=-1),\n",
        "    fc.categorical_column_with_vocabulary_list('sex', [0, 1]),\n",
        "    fc.numeric_column('cp'),\n",
        "    fc.numeric_column('trestbps', default_value=-1),\n",
        "    fc.numeric_column('chol'),\n",
        "    fc.categorical_column_with_vocabulary_list('fbs', [0, 1]),\n",
        "    fc.categorical_column_with_vocabulary_list('restecg', [0, 1, 2]),\n",
        "    fc.numeric_column('thalach'),\n",
        "    fc.categorical_column_with_vocabulary_list('exang', [0, 1]),\n",
        "    fc.numeric_column('oldpeak'),\n",
        "    fc.categorical_column_with_vocabulary_list('slope', [0, 1, 2]),\n",
        "    fc.numeric_column('ca'),\n",
        "    fc.categorical_column_with_vocabulary_list(\n",
        "        'thal', ['normal', 'fixed', 'reversible']),\n",
        "]"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "hEZstmtT2CA3"
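      },
      "source": [
        "TFL canned estimators use the type of each feature column to decide which calibration layer to use: a `tfl.layers.PWLCalibration` layer for numeric feature columns and a `tfl.layers.CategoricalCalibration` layer for categorical feature columns.\n",
        "\n",
        "Note that categorical feature columns are not wrapped by an embedding feature column. They are fed directly into the estimator."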
      ]
    },
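    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "To make the two calibrator types concrete, the cell below is a minimal NumPy sketch of what each one computes. It is not the TFL implementation. The keypoints, output values, and helper names (`pwl_calibrate`, `categorical_calibrate`) are hypothetical and chosen only for illustration; in a real model the outputs are learned during training."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "import numpy as np\n",
        "\n",
        "# Hypothetical input keypoints for a numeric feature such as 'age'.\n",
        "input_keypoints = np.array([29.0, 44.0, 54.0, 61.0, 77.0])\n",
        "# Hypothetical learned outputs, one per keypoint (monotonically increasing).\n",
        "output_keypoints = np.array([0.0, 0.2, 0.5, 0.9, 1.0])\n",
        "\n",
        "\n",
        "def pwl_calibrate(x):\n",
        "  # Linear interpolation between keypoints. np.interp returns the first or\n",
        "  # last output for inputs that fall outside the keypoint range.\n",
        "  return np.interp(x, input_keypoints, output_keypoints)\n",
        "\n",
        "\n",
        "# Hypothetical learned outputs for a categorical feature such as 'thal'.\n",
        "category_outputs = {'normal': 0.1, 'fixed': 0.6, 'reversible': 0.8}\n",
        "\n",
        "\n",
        "def categorical_calibrate(value, default=0.0):\n",
        "  # One output per category; unknown values fall back to a default.\n",
        "  return category_outputs.get(value, default)\n",
        "\n",
        "\n",
        "print(pwl_calibrate(np.array([35.0, 54.0, 90.0])))\n",
        "print(categorical_calibrate('fixed'))"
      ]
    },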
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "H_LoW_9m5OFL"
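      },
      "source": [
        "## Creating input_fn\n",
        "\n",
        "As with any other estimator, you can use an input_fn to feed data to the model for training and evaluation. TFL estimators can automatically calculate quantiles of the features and use them as input keypoints for the PWL calibration layer. To do so, they require a `feature_analysis_input_fn`, which is similar to the training input_fn but runs for a single epoch or over a subsample of the data."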
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "lFVy1Efy5NKD"
      },
      "outputs": [],
      "source": [
        "train_input_fn = tf.compat.v1.estimator.inputs.pandas_input_fn(\n",
        "    x=train_x,\n",
        "    y=train_y,\n",
        "    shuffle=False,\n",
        "    batch_size=BATCH_SIZE,\n",
        "    num_epochs=NUM_EPOCHS,\n",
        "    num_threads=1)\n",
        "\n",
        "# feature_analysis_input_fn is used to collect statistics about the input.\n",
        "feature_analysis_input_fn = tf.compat.v1.estimator.inputs.pandas_input_fn(\n",
        "    x=train_x,\n",
        "    y=train_y,\n",
        "    shuffle=False,\n",
        "    batch_size=BATCH_SIZE,\n",
        "    # Note that we only need one pass over the data.\n",
        "    num_epochs=1,\n",
        "    num_threads=1)\n",
        "\n",
        "test_input_fn = tf.compat.v1.estimator.inputs.pandas_input_fn(\n",
        "    x=test_x,\n",
        "    y=test_y,\n",
        "    shuffle=False,\n",
        "    batch_size=BATCH_SIZE,\n",
        "    num_epochs=1,\n",
        "    num_threads=1)\n",
        "\n",
        "# Serving input fn is used to create saved models.\n",
        "serving_input_fn = (\n",
        "    tf.estimator.export.build_parsing_serving_input_receiver_fn(\n",
        "        feature_spec=fc.make_parse_example_spec(feature_columns)))"
      ]
    },
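    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The quantile-based keypoints mentioned above can be approximated by hand. The cell below is a rough sketch that assumes keypoints are evenly spaced quantiles of the observed values. The sample values are made up, and the feature analysis performed by TFL may differ in details, so treat this as an illustration of the idea rather than the library's procedure. Uniform keypoints are shown for contrast."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "import numpy as np\n",
        "\n",
        "# Hypothetical sample of a numeric feature. In this guide the values would\n",
        "# come from one pass over feature_analysis_input_fn.\n",
        "values = np.array(\n",
        "    [29, 35, 41, 44, 48, 52, 54, 57, 61, 63, 67, 77], dtype=float)\n",
        "num_keypoints = 5\n",
        "\n",
        "# Evenly spaced quantile levels give keypoints that follow the data\n",
        "# distribution, so dense regions get more keypoints.\n",
        "quantile_keypoints = np.quantile(values, np.linspace(0.0, 1.0, num_keypoints))\n",
        "# Uniform keypoints depend only on the observed minimum and maximum.\n",
        "uniform_keypoints = np.linspace(values.min(), values.max(), num_keypoints)\n",
        "\n",
        "print('quantile keypoints:', quantile_keypoints)\n",
        "print('uniform keypoints: ', uniform_keypoints)"
      ]
    },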
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "uQlzREcm2Wbj"
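      },
      "source": [
        "## Feature Configs\n",
        "\n",
        "Feature calibration and per-feature configurations are set using <code>tfl.configs.FeatureConfig</code>. Feature configurations include monotonicity constraints, per-feature regularization (see <code>tfl.configs.RegularizerConfig</code>), and lattice sizes for lattice models.\n",
        "\n",
        "If no configuration is defined for an input feature, the default configuration in `tfl.configs.FeatureConfig` is used."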
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "vD0tNpiO3p9c"
      },
      "outputs": [],
      "source": [
        "# Feature configs are used to specify how each feature is calibrated and used.\n",
        "feature_configs = [\n",
        "    tfl.configs.FeatureConfig(\n",
        "        name='age',\n",
        "        lattice_size=3,\n",
        "        # By default, input keypoints of pwl are quantiles of the feature.\n",
        "        pwl_calibration_num_keypoints=5,\n",
        "        monotonicity='increasing',\n",
        "        pwl_calibration_clip_max=100,\n",
        "        # Per feature regularization.\n",
        "        regularizer_configs=[\n",
        "            tfl.configs.RegularizerConfig(name='calib_wrinkle', l2=0.1),\n",
        "        ],\n",
        "    ),\n",
        "    tfl.configs.FeatureConfig(\n",
        "        name='cp',\n",
        "        pwl_calibration_num_keypoints=4,\n",
        "        # Keypoints can be uniformly spaced.\n",
        "        pwl_calibration_input_keypoints='uniform',\n",
        "        monotonicity='increasing',\n",
        "    ),\n",
        "    tfl.configs.FeatureConfig(\n",
        "        name='chol',\n",
        "        # Explicit input keypoint initialization.\n",
        "        pwl_calibration_input_keypoints=[126.0, 210.0, 247.0, 286.0, 564.0],\n",
        "        monotonicity='increasing',\n",
        "        # Calibration can be forced to span the full output range by clamping.\n",
        "        pwl_calibration_clamp_min=True,\n",
        "        pwl_calibration_clamp_max=True,\n",
        "        # Per feature regularization.\n",
        "        regularizer_configs=[\n",
        "            tfl.configs.RegularizerConfig(name='calib_hessian', l2=1e-4),\n",
        "        ],\n",
        "    ),\n",
        "    tfl.configs.FeatureConfig(\n",
        "        name='fbs',\n",
        "        # Partial monotonicity: output(0) <= output(1)\n",
        "        monotonicity=[(0, 1)],\n",
        "    ),\n",
        "    tfl.configs.FeatureConfig(\n",
        "        name='trestbps',\n",
        "        pwl_calibration_num_keypoints=5,\n",
        "        monotonicity='decreasing',\n",
        "    ),\n",
        "    tfl.configs.FeatureConfig(\n",
        "        name='thalach',\n",
        "        pwl_calibration_num_keypoints=5,\n",
        "        monotonicity='decreasing',\n",
        "    ),\n",
        "    tfl.configs.FeatureConfig(\n",
        "        name='restecg',\n",
        "        # Partial monotonicity: output(0) <= output(1), output(0) <= output(2)\n",
        "        monotonicity=[(0, 1), (0, 2)],\n",
        "    ),\n",
        "    tfl.configs.FeatureConfig(\n",
        "        name='exang',\n",
        "        # Partial monotonicity: output(0) <= output(1)\n",
        "        monotonicity=[(0, 1)],\n",
        "    ),\n",
        "    tfl.configs.FeatureConfig(\n",
        "        name='oldpeak',\n",
        "        pwl_calibration_num_keypoints=5,\n",
        "        monotonicity='increasing',\n",
        "    ),\n",
        "    tfl.configs.FeatureConfig(\n",
        "        name='slope',\n",
        "        # Partial monotonicity: output(0) <= output(1), output(1) <= output(2)\n",
        "        monotonicity=[(0, 1), (1, 2)],\n",
        "    ),\n",
        "    tfl.configs.FeatureConfig(\n",
        "        name='ca',\n",
        "        pwl_calibration_num_keypoints=4,\n",
        "        monotonicity='increasing',\n",
        "    ),\n",
        "    tfl.configs.FeatureConfig(\n",
        "        name='thal',\n",
        "        # Partial monotonicity:\n",
        "        # output(normal) <= output(fixed)\n",
        "        # output(normal) <= output(reversible)\n",
        "        monotonicity=[('normal', 'fixed'), ('normal', 'reversible')],\n",
        "    ),\n",
        "]"
      ]
    },
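    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "For categorical features, each `(a, b)` pair in `monotonicity` asks that the calibrator output for `a` not exceed the output for `b`. The cell below sketches what that constraint means by checking it on made-up calibrator outputs. The helper `satisfies_pairs` and the output values are hypothetical and are not part of TFL."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "def satisfies_pairs(outputs, pairs):\n",
        "  # Each pair (a, b) requires outputs[a] <= outputs[b].\n",
        "  return all(outputs[a] <= outputs[b] for a, b in pairs)\n",
        "\n",
        "\n",
        "# Same pairs as the 'thal' feature config in this guide.\n",
        "thal_pairs = [('normal', 'fixed'), ('normal', 'reversible')]\n",
        "\n",
        "# Hypothetical outputs. 'fixed' and 'reversible' are not ordered relative to\n",
        "# each other, so this assignment satisfies both pairs.\n",
        "ok_outputs = {'normal': 0.1, 'fixed': 0.7, 'reversible': 0.5}\n",
        "# Hypothetical outputs that violate output(normal) <= output(fixed).\n",
        "bad_outputs = {'normal': 0.9, 'fixed': 0.7, 'reversible': 1.0}\n",
        "\n",
        "print(satisfies_pairs(ok_outputs, thal_pairs))   # True\n",
        "print(satisfies_pairs(bad_outputs, thal_pairs))  # False"
      ]
    },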
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "LKBULveZ4mr3"
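      },
      "source": [
        "## Calibrated Linear Model\n",
        "\n",
        "To construct a TFL canned estimator, construct a model configuration from `tfl.configs`. A calibrated linear model is constructed using `tfl.configs.CalibratedLinearConfig`. It applies piecewise-linear and categorical calibration to the input features, followed by a linear combination and an optional output piecewise-linear calibration. When output calibration is used or output bounds are specified, the linear layer applies a weighted average to the calibrated inputs.\n",
        "\n",
        "This example creates a calibrated linear model on the first 5 features. We use `tfl.visualization` to plot the model graph with the calibrator plots."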
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "diRRozio4sAL"
      },
      "outputs": [],
      "source": [
        "# Model config defines the model structure for the estimator.\n",
        "model_config = tfl.configs.CalibratedLinearConfig(\n",
        "    feature_configs=feature_configs,\n",
        "    use_bias=True,\n",
        "    output_calibration=True,\n",
        "    regularizer_configs=[\n",
        "        # Regularizer for the output calibrator.\n",
        "        tfl.configs.RegularizerConfig(name='output_calib_hessian', l2=1e-4),\n",
        "    ])\n",
        "# A CannedClassifier is constructed from the given model config.\n",
        "estimator = tfl.estimators.CannedClassifier(\n",
        "    feature_columns=feature_columns[:5],\n",
        "    model_config=model_config,\n",
        "    feature_analysis_input_fn=feature_analysis_input_fn,\n",
        "    optimizer=tf.keras.optimizers.Adam(LEARNING_RATE),\n",
        "    config=tf.estimator.RunConfig(tf_random_seed=42))\n",
        "estimator.train(input_fn=train_input_fn)\n",
        "results = estimator.evaluate(input_fn=test_input_fn)\n",
        "print('Calibrated linear test AUC: {}'.format(results['auc']))\n",
        "saved_model_path = estimator.export_saved_model(estimator.model_dir,\n",
        "                                                serving_input_fn)\n",
        "model_graph = tfl.estimators.get_model_graph(saved_model_path)\n",
        "tfl.visualization.draw_model_graph(model_graph)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "zWzPM2_p977t"
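      },
      "source": [
        "## Calibrated Lattice Model\n",
        "\n",
        "A calibrated lattice model is constructed using `tfl.configs.CalibratedLatticeConfig`. It applies piecewise-linear and categorical calibration to the input features, followed by a lattice model and an optional output piecewise-linear calibration.\n",
        "\n",
        "This example creates a calibrated lattice model on the first 5 features.\n"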
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "C6EvVpKW4BbC"
      },
      "outputs": [],
      "source": [
        "# This is calibrated lattice model: Inputs are calibrated, then combined\n",
        "# non-linearly using a lattice layer.\n",
        "model_config = tfl.configs.CalibratedLatticeConfig(\n",
        "    feature_configs=feature_configs,\n",
        "    regularizer_configs=[\n",
        "        # Torsion regularizer applied to the lattice to make it more linear.\n",
        "        tfl.configs.RegularizerConfig(name='torsion', l2=1e-4),\n",
        "        # Globally defined calibration regularizer is applied to all features.\n",
        "        tfl.configs.RegularizerConfig(name='calib_hessian', l2=1e-4),\n",
        "    ])\n",
        "# A CannedClassifier is constructed from the given model config.\n",
        "estimator = tfl.estimators.CannedClassifier(\n",
        "    feature_columns=feature_columns[:5],\n",
        "    model_config=model_config,\n",
        "    feature_analysis_input_fn=feature_analysis_input_fn,\n",
        "    optimizer=tf.keras.optimizers.Adam(LEARNING_RATE),\n",
        "    config=tf.estimator.RunConfig(tf_random_seed=42))\n",
        "estimator.train(input_fn=train_input_fn)\n",
        "results = estimator.evaluate(input_fn=test_input_fn)\n",
        "print('Calibrated lattice test AUC: {}'.format(results['auc']))\n",
        "saved_model_path = estimator.export_saved_model(estimator.model_dir,\n",
        "                                                serving_input_fn)\n",
        "model_graph = tfl.estimators.get_model_graph(saved_model_path)\n",
        "tfl.visualization.draw_model_graph(model_graph)"
      ]
    },
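    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "A lattice can be thought of as an interpolated look-up table: parameters sit on the vertices of a grid, and the output is a multilinear interpolation between them. The cell below hand-codes this for a hypothetical 2x2 lattice over two calibrated inputs in [0, 1]. The function name and vertex values are illustrative assumptions, and this is a sketch of the idea rather than the `tfl.layers.Lattice` implementation."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "def lattice_2x2(x0, x1, params):\n",
        "  # params[i][j] is the value at vertex (x0=i, x1=j); x0 and x1 lie in [0, 1].\n",
        "  at_x1_low = (1 - x0) * params[0][0] + x0 * params[1][0]\n",
        "  at_x1_high = (1 - x0) * params[0][1] + x0 * params[1][1]\n",
        "  # Interpolate along x1 between the two intermediate values.\n",
        "  return (1 - x1) * at_x1_low + x1 * at_x1_high\n",
        "\n",
        "\n",
        "# Hypothetical vertex values that increase along both inputs, so the\n",
        "# interpolated function is monotonically increasing in x0 and x1.\n",
        "params = [[0.0, 0.4], [0.3, 1.0]]\n",
        "\n",
        "print(lattice_2x2(0.0, 0.0, params))  # Value at vertex (0, 0).\n",
        "print(lattice_2x2(1.0, 1.0, params))  # Value at vertex (1, 1).\n",
        "print(lattice_2x2(0.5, 0.5, params))  # Blend of all four vertices."
      ]
    },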
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "9494K_ZBKFcm"
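      },
      "source": [
        "## Calibrated Lattice Ensemble\n",
        "\n",
        "When the number of features is large, you can use an ensemble model, which creates multiple smaller lattices for subsets of the features and averages their outputs, instead of creating a single huge lattice. Ensemble lattice models are constructed using `tfl.configs.CalibratedLatticeEnsembleConfig`. A calibrated lattice ensemble model applies piecewise-linear and categorical calibration to the input features, followed by an ensemble of lattice models and an optional output piecewise-linear calibration.\n"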
      ]
    },
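    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "A rough count shows why ensembles help. Assuming, for illustration only, that every feature uses a lattice size of 2, a lattice has one parameter per vertex, so a single lattice over all features grows exponentially with the number of features, while an ensemble grows linearly with the number of lattices. The `num_lattices=5` and `lattice_rank=3` values below match the ensemble configs used in this guide."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "num_features = 13  # Number of feature columns in this guide.\n",
        "lattice_size = 2   # Assumed uniform lattice size, for illustration only.\n",
        "num_lattices = 5\n",
        "lattice_rank = 3   # Number of features per lattice in the ensemble.\n",
        "\n",
        "# One vertex per combination of per-feature grid positions.\n",
        "single_lattice_vertices = lattice_size ** num_features\n",
        "ensemble_vertices = num_lattices * lattice_size ** lattice_rank\n",
        "\n",
        "print('single lattice vertices:', single_lattice_vertices)  # 8192\n",
        "print('ensemble vertices:', ensemble_vertices)  # 40"
      ]
    },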
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "KjrzziMFKuCB"
      },
      "source": [
        "### Random Lattice Ensemble\n",
        "\n",
        "The following model config uses a random subset of features for each lattice."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "YBSS7dLjKExq"
      },
      "outputs": [],
      "source": [
        "# This is random lattice ensemble model with separate calibration:\n",
        "# model output is the average output of separately calibrated lattices.\n",
        "model_config = tfl.configs.CalibratedLatticeEnsembleConfig(\n",
        "    feature_configs=feature_configs,\n",
        "    num_lattices=5,\n",
        "    lattice_rank=3)\n",
        "# A CannedClassifier is constructed from the given model config.\n",
        "estimator = tfl.estimators.CannedClassifier(\n",
        "    feature_columns=feature_columns,\n",
        "    model_config=model_config,\n",
        "    feature_analysis_input_fn=feature_analysis_input_fn,\n",
        "    optimizer=tf.keras.optimizers.Adam(LEARNING_RATE),\n",
        "    config=tf.estimator.RunConfig(tf_random_seed=42))\n",
        "estimator.train(input_fn=train_input_fn)\n",
        "results = estimator.evaluate(input_fn=test_input_fn)\n",
        "print('Random ensemble test AUC: {}'.format(results['auc']))\n",
        "saved_model_path = estimator.export_saved_model(estimator.model_dir,\n",
        "                                                serving_input_fn)\n",
        "model_graph = tfl.estimators.get_model_graph(saved_model_path)\n",
        "tfl.visualization.draw_model_graph(model_graph, calibrator_dpi=15)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "7uyO8s97FGJM"
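      },
      "source": [
        "### RTL Layer Random Lattice Ensemble\n",
        "\n",
        "The following model config uses a `tfl.layers.RTL` layer, which uses a random subset of features for each lattice. Note that `tfl.layers.RTL` supports only monotonicity constraints, requires the same lattice size for all features, and does not support per-feature regularization. Using a `tfl.layers.RTL` layer lets you scale to much larger ensembles than using separate `tfl.layers.Lattice` instances."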
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "8v7dKg-FF7iz"
      },
      "outputs": [],
      "source": [
        "# Make sure our feature configs have the same lattice size, no per-feature\n",
        "# regularization, and only monotonicity constraints.\n",
        "rtl_layer_feature_configs = copy.deepcopy(feature_configs)\n",
        "for feature_config in rtl_layer_feature_configs:\n",
        "  feature_config.lattice_size = 2\n",
        "  feature_config.unimodality = 'none'\n",
        "  feature_config.reflects_trust_in = None\n",
        "  feature_config.dominates = None\n",
        "  feature_config.regularizer_configs = None\n",
        "# This is RTL layer ensemble model with separate calibration:\n",
        "# model output is the average output of separately calibrated lattices.\n",
        "model_config = tfl.configs.CalibratedLatticeEnsembleConfig(\n",
        "    lattices='rtl_layer',\n",
        "    feature_configs=rtl_layer_feature_configs,\n",
        "    num_lattices=5,\n",
        "    lattice_rank=3)\n",
        "# A CannedClassifier is constructed from the given model config.\n",
        "estimator = tfl.estimators.CannedClassifier(\n",
        "    feature_columns=feature_columns,\n",
        "    model_config=model_config,\n",
        "    feature_analysis_input_fn=feature_analysis_input_fn,\n",
        "    optimizer=tf.keras.optimizers.Adam(LEARNING_RATE),\n",
        "    config=tf.estimator.RunConfig(tf_random_seed=42))\n",
        "estimator.train(input_fn=train_input_fn)\n",
        "results = estimator.evaluate(input_fn=test_input_fn)\n",
        "print('Random ensemble test AUC: {}'.format(results['auc']))\n",
        "saved_model_path = estimator.export_saved_model(estimator.model_dir,\n",
        "                                                serving_input_fn)\n",
        "model_graph = tfl.estimators.get_model_graph(saved_model_path)\n",
        "tfl.visualization.draw_model_graph(model_graph, calibrator_dpi=15)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "LSXEaYAULRvf"
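      },
      "source": [
        "### Crystals Lattice Ensemble\n",
        "\n",
        "TFL also provides a heuristic feature arrangement algorithm called [Crystals](https://papers.nips.cc/paper/6377-fast-and-flexible-monotonic-functions-with-ensembles-of-lattices). The Crystals algorithm first trains a *prefitting model* that estimates pairwise feature interactions. It then arranges the final ensemble so that features with more non-linear interactions are in the same lattices.\n",
        "\n",
        "For Crystals models, you also need to provide a `prefitting_input_fn` that is used to train the prefitting model described above. The prefitting model does not need to be fully trained, so a few epochs should be enough.\n"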
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "FjQKh9saMaFu"
      },
      "outputs": [],
      "source": [
        "prefitting_input_fn = tf.compat.v1.estimator.inputs.pandas_input_fn(\n",
        "    x=train_x,\n",
        "    y=train_y,\n",
        "    shuffle=False,\n",
        "    batch_size=BATCH_SIZE,\n",
        "    num_epochs=PREFITTING_NUM_EPOCHS,\n",
        "    num_threads=1)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "fVnZpwX8MtPi"
      },
      "source": [
        "You can then create a Crystals model by setting `lattices='crystals'` in the model config."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "f4awRMDe-eMv"
      },
      "outputs": [],
      "source": [
        "# This is Crystals ensemble model with separate calibration: model output is\n",
        "# the average output of separately calibrated lattices.\n",
        "model_config = tfl.configs.CalibratedLatticeEnsembleConfig(\n",
        "    feature_configs=feature_configs,\n",
        "    lattices='crystals',\n",
        "    num_lattices=5,\n",
        "    lattice_rank=3)\n",
        "# A CannedClassifier is constructed from the given model config.\n",
        "estimator = tfl.estimators.CannedClassifier(\n",
        "    feature_columns=feature_columns,\n",
        "    model_config=model_config,\n",
        "    feature_analysis_input_fn=feature_analysis_input_fn,\n",
        "    # prefitting_input_fn is required to train the prefitting model.\n",
        "    prefitting_input_fn=prefitting_input_fn,\n",
        "    optimizer=tf.keras.optimizers.Adam(LEARNING_RATE),\n",
        "    prefitting_optimizer=tf.keras.optimizers.Adam(LEARNING_RATE),\n",
        "    config=tf.estimator.RunConfig(tf_random_seed=42))\n",
        "estimator.train(input_fn=train_input_fn)\n",
        "results = estimator.evaluate(input_fn=test_input_fn)\n",
        "print('Crystals ensemble test AUC: {}'.format(results['auc']))\n",
        "saved_model_path = estimator.export_saved_model(estimator.model_dir,\n",
        "                                                serving_input_fn)\n",
        "model_graph = tfl.estimators.get_model_graph(saved_model_path)\n",
        "tfl.visualization.draw_model_graph(model_graph, calibrator_dpi=15)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Isb2vyLAVBM1"
      },
      "source": [
        "You can plot feature calibrators in more detail using the `tfl.visualization` module."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "DJPaREuWS2sg"
      },
      "outputs": [],
      "source": [
        "_ = tfl.visualization.plot_feature_calibrator(model_graph, \"age\")\n",
        "_ = tfl.visualization.plot_feature_calibrator(model_graph, \"restecg\")"
      ]
    }
  ],
  "metadata": {
    "colab": {
      "collapsed_sections": [],
      "name": "canned_estimators.ipynb",
      "toc_visible": true
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
